{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is intended to demonstrate how select registration, segmentation, and image mathematical methods of ITKTubeTK can be combined to perform multi-channel brain extraction (aka. skull stripping for patient data containing multiple MRI sequences).\n",
    "\n",
    "There are many other (probably more effective) brain extraction methods available as open-source software such as BET and BET2 in the FSL package (albeit such methods are only for single channel data).   If you need to perform brain extraction for a large collection of scans that do not contain major pathologies, please use one of those packages.   This notebook is meant to show off the capabilities of specific ITKTubeTK methods, not to demonstration how to \"solve\" brain extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itk\n",
    "from itk import TubeTK as ttk\n",
    "\n",
    "from itkwidgets import view\n",
    "\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ImageType = itk.Image[itk.F, 3]\n",
    "\n",
    "ReaderType = itk.ImageFileReader[ImageType]\n",
    "\n",
    "InputBaseName = \"../Data/CTP/CTP32\"\n",
    "\n",
    "filename = InputBaseName + \".mha\"\n",
    "reader1 = ReaderType.New(FileName=filename)\n",
    "reader1.Update()\n",
    "im1 = reader1.GetOutput()\n",
    "\n",
    "resamp = ttk.ResampleImage[ImageType].New(Input = im1)\n",
    "resamp.SetMakeHighResIso(True)\n",
    "resamp.Update()\n",
    "im1iso = resamp.GetOutput()\n",
    "\n",
    "filename = InputBaseName + \"-Iso.mha\"\n",
    "itk.imwrite(im1iso, filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 8\n",
    "readerList = [\"003\", \"010\", \"026\", \"034\", \"045\", \"056\", \"063\", \"071\"]\n",
    "\n",
    "imBase = []\n",
    "imBaseB = []\n",
    "for i in range(0,N):\n",
    "    name = \"../Data/MRI-Normals/Normal\"+readerList[i]+\"-FLASH.mha\"\n",
    "    nameB = \"../Data/MRI-Normals/Normal\"+readerList[i]+\"-FLASH-Brain.mha\"\n",
    "    reader = ReaderType.New(FileName=name)\n",
    "    reader.Update()\n",
    "    imMathNorm = ttk.ImageMath.New(Input=reader.GetOutput())\n",
    "    imMathNorm.NormalizeMeanStdDev()\n",
    "    imBaseTmp = imMathNorm.GetOutput()\n",
    "    reader = ReaderType.New(FileName=nameB)\n",
    "    reader.Update()\n",
    "    imBaseBTmp = reader.GetOutput()\n",
    "    imBase.append(imBaseTmp)\n",
    "    imBaseB.append(imBaseBTmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14f25fe5c469407aba985daf97397866",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(im1iso)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d08565fa56e46d490abaf32653c9f67",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "thresh = ttk.ImageMath.New(Input=im1iso)\n",
    "thresh.ReplaceValuesOutsideMaskRange(im1iso,-150,600,-200) #-150,100,-200 for min\n",
    "thresh.Threshold(-199,600,1,0)\n",
    "#thresh.NormalizeMeanStdDev()\n",
    "im1isoT = thresh.GetOutput()\n",
    "im1iso = im1isoT\n",
    "view(im1iso)\n",
    "#itk.imwrite(im1iso,\"im1iso.mha\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "maskMath = ttk.ImageMath.New(Input=im1iso)\n",
    "#maskMath.Threshold(-2,2,1,0)\n",
    "maskMath.Dilate(20,1,0)\n",
    "maskMathD = maskMath.GetOutput()\n",
    "maskMath.SetInput(im1iso)\n",
    "maskMath.Erode(20,1,0)\n",
    "maskMath.AddImages(maskMathD,-1,1)\n",
    "mask = maskMath.GetOutput()\n",
    "#itk.imwrite(mask, \"mask.mha\")\n",
    "maskObject = itk.ImageSpatialObject[3,itk.F].New(Image=mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "maskMath.SetInput(im1iso)\n",
    "maskMath.Blur(3)\n",
    "im1isoBlur = maskMath.GetOutput()\n",
    "#itk.imwrite(im1isoBlur,\"im1isoBlur.mha\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*** 12% : 129s ***\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-13-9b541ca871fe>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     24\u001b[0m     \u001b[1;31m#regBTo1.SetMetric(\"NORMALIZED_CORRELATION_METRIC\")\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m     \u001b[0mregBTo1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mUpdate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m     \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mregBTo1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mResampleImage\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     27\u001b[0m     \u001b[0mregB\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m \u001b[0mimg\u001b[0m \u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     28\u001b[0m     \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mregBTo1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mResampleImage\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"LINEAR\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimBaseB\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "RegisterImagesType = ttk.RegisterImages[ImageType]\n",
    "regB = []\n",
    "regBB = []\n",
    "for i in range(0,N):\n",
    "    start = time.time()\n",
    "    maskMath.SetInput(imBase[i])\n",
    "    maskMath.Blur(3)\n",
    "    movingIm = maskMath.GetOutput()\n",
    "    #itk.imwrite(movingIm,\"movingIm.mha\")\n",
    "    regBTo1 = RegisterImagesType.New(FixedImage=im1isoBlur, MovingImage=movingIm)\n",
    "    regBTo1.SetReportProgress(True)\n",
    "    regBTo1.SetExpectedOffsetMagnitude(40)\n",
    "    regBTo1.SetExpectedRotationMagnitude(0.01)\n",
    "    regBTo1.SetExpectedScaleMagnitude(0.1)\n",
    "    regBTo1.SetRigidMaxIterations(500)\n",
    "    regBTo1.SetAffineMaxIterations(500)\n",
    "    regBTo1.SetRigidSamplingRatio(0.1)\n",
    "    regBTo1.SetAffineSamplingRatio(0.1)\n",
    "    regBTo1.SetInitialMethodEnum(\"INIT_WITH_IMAGE_CENTERS\")\n",
    "    regBTo1.SetFixedImageMaskObject(maskObject)\n",
    "    regBTo1.SetUseFixedImageMaskObject(True)\n",
    "    regBTo1.SetRegistration(\"PIPELINE_AFFINE\")\n",
    "    regBTo1.SetMetric(\"MATTES_MI_METRIC\")\n",
    "    #regBTo1.SetMetric(\"NORMALIZED_CORRELATION_METRIC\")\n",
    "    regBTo1.Update()\n",
    "    img = regBTo1.ResampleImage()\n",
    "    regB.append( img )\n",
    "    img = regBTo1.ResampleImage(\"LINEAR\", imBaseB[i])\n",
    "    regBB.append( img )\n",
    "    end = time.time()\n",
    "    print('*** ' + str(round((i+1/N)*100)) + '% : ' + str(round((end-start))) + 's ***')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94288b2ee03b42309120b15a3fe06bba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "regBBT = []\n",
    "for i in range(0,N):\n",
    "    imMath = ttk.ImageMath[ImageType,ImageType].New( Input=regBB[i] )\n",
    "    imMath.Threshold(0,1,0,1)\n",
    "    img = imMath.GetOutput()\n",
    "    if i==0:\n",
    "        imMathSum = ttk.ImageMath[ImageType,ImageType].New( img )\n",
    "        imMathSum.AddImages( img, 1.0/N, 0 )\n",
    "        sumBBT = imMathSum.GetOutput()\n",
    "    else:\n",
    "        imMathSum = ttk.ImageMath[ImageType,ImageType].New( sumBBT )\n",
    "        imMathSum.AddImages( img, 1, 1.0/N )\n",
    "        sumBBT = imMathSum.GetOutput()\n",
    "        \n",
    "view(sumBBT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "insideMath = ttk.ImageMath[ImageType,ImageType].New( Input = sumBBT )\n",
    "insideMath.Threshold(0,0.9,0,1)\n",
    "insideMath.Dilate(5,1,0)\n",
    "insideMath.Erode(25,1,0)\n",
    "brainInside = insideMath.GetOutput()\n",
    "\n",
    "outsideMath = ttk.ImageMath[ImageType,ImageType].New( Input = sumBBT )\n",
    "outsideMath.Threshold(0,0,1,0)\n",
    "outsideMath.Erode(1,1,0)\n",
    "brainOutsideAll = outsideMath.GetOutput()\n",
    "outsideMath.Erode(20,1,0)\n",
    "outsideMath.AddImages(brainOutsideAll, -1, 1)\n",
    "brainOutside = outsideMath.GetOutput()\n",
    "\n",
    "outsideMath.AddImages(brainInside,1,2)\n",
    "brainCombinedMask = outsideMath.GetOutputUChar()\n",
    "\n",
    "outsideMath.AddImages(im1iso, 10, 1)\n",
    "brainCombinedMaskView = outsideMath.GetOutput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61c74586334348e7831a5bb3198759df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(brainCombinedMaskView)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "LabelMapType = itk.Image[itk.UC,3]\n",
    "\n",
    "segmenter = ttk.SegmentConnectedComponentsUsingParzenPDFs[ImageType,LabelMapType].New()\n",
    "segmenter.SetFeatureImage( im1iso )\n",
    "segmenter.SetInputLabelMap( brainCombinedMask )\n",
    "segmenter.SetObjectId( 2 )\n",
    "segmenter.AddObjectId( 1 )\n",
    "segmenter.SetVoidId( 0 )\n",
    "segmenter.SetErodeDilateRadius( 15 )\n",
    "segmenter.SetHoleFillIterations( 40 )\n",
    "segmenter.Update()\n",
    "segmenter.ClassifyImages()\n",
    "brainCombinedMaskClassified = segmenter.GetOutputLabelMap()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b3a4c7e4620047399b9fa52b1071c7a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageUC3; pr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(brainCombinedMaskClassified)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "cast = itk.CastImageFilter[LabelMapType, ImageType].New()\n",
    "cast.SetInput(brainCombinedMaskClassified)\n",
    "cast.Update()\n",
    "brainMaskF = cast.GetOutput()\n",
    "\n",
    "brainMath = ttk.ImageMath[ImageType,ImageType].New(Input = brainMaskF)\n",
    "brainMath.Threshold(2,2,1,0)\n",
    "#brainMath.Dilate(1,1,0)\n",
    "brainMaskD = brainMath.GetOutput()\n",
    "brainMath.SetInput( im1iso )\n",
    "brainMath.ReplaceValuesOutsideMaskRange( brainMaskD, 1, 1, 0)\n",
    "brain = brainMath.GetOutput()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c30d2e98d677480cbbc2f7c96e450e33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageF3; pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "view(brain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = itk.ImageFileWriter[ImageType].New(Input = brain)\n",
    "filename = InputBaseName + \"-Brain.mha\"\n",
    "writer.SetFileName(filename)\n",
    "writer.Update()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
